"""
Training script for Sparse Autoencoder in JAX with model saving and checkpointing.
"""
import jax
import jax.numpy as jnp
import numpy as np
import optax
import time
import os
from functools import partial
from transformers import AutoModelForCausalLM

# Import the JAX implementation of Sparse Autoencoder
from sae_jax import (
    SparseAutoencoder, 
    train, 
)

# Import saving and loading utilities
from sae_save_load import (
    save_model, 
    load_model, 
    save_checkpoint, 
    load_checkpoint,
    save_metadata
)

# File name shared by the preprocessing step (writer) and training (reader).
_UNEMBEDDINGS_FILENAME = "clean_unembeddings.npy"


def _model_slug(model_name):
    """Return the lowercase model name with any organization prefix removed."""
    return model_name.split('/')[-1].lower()


def _home_path(*parts):
    """Join ``parts`` beneath the current user's home directory.

    A leading ``~`` is only expanded by a shell or by ``os.path.expanduser``;
    ``os.makedirs`` and the NumPy/JAX load/save helpers would otherwise create
    or look for a directory literally named ``~`` under the working directory.
    """
    return os.path.join(os.path.expanduser("~"), *parts)


def _resolve_unembeddings_file(model_name):
    """Return the path of the preprocessed unembeddings file for ``model_name``.

    Prefers ``<home>/unembeddings/<model>/``. Earlier versions of this script
    never expanded ``~`` and therefore wrote into a literal ``./~/`` directory;
    if the file only exists there, that legacy location is returned so data
    produced by an older preprocessing run remains usable.
    """
    slug = _model_slug(model_name)
    data_path = _home_path("unembeddings", slug, _UNEMBEDDINGS_FILENAME)
    legacy_path = os.path.join("~", "unembeddings", slug, _UNEMBEDDINGS_FILENAME)
    if not os.path.exists(data_path) and os.path.exists(legacy_path):
        print(f"Using legacy (unexpanded '~') data location: {legacy_path}")
        return legacy_path
    return data_path


def main(args):
    """
    Train a Sparse Autoencoder with periodic checkpoints and save the result.

    Loads the preprocessed unembedding matrix for ``args.model_name``, builds a
    fresh model or resumes from ``args.resume``, trains it, and writes
    checkpoints, a metadata file, and the final model to ``<home>/<model>-sae/``.

    Args:
        args: Parsed command line arguments; ``model_name`` and ``resume`` are read.
    """
    # Create output directory (under the real home directory, not a literal "~")
    output_dir = _home_path(f"{_model_slug(args.model_name)}-sae")
    os.makedirs(output_dir, exist_ok=True)

    # Load data first to get embed_dim
    print("Loading input data...")
    inputs_np = jnp.load(_resolve_unembeddings_file(args.model_name))
    # Rescale by sqrt(rows / cols) to set the norms to be close to 1
    inputs_np = inputs_np * jnp.sqrt(inputs_np.shape[0] / inputs_np.shape[1])
    print(f"Loaded data with shape: {inputs_np.shape}")
    embed_dim = inputs_np.shape[1]

    # Configure training
    hidden_dim = 49152
    k = 5
    num_epochs = 1000
    batch_size = 4096
    checkpoint_every = 100  # Save a checkpoint every 100 epochs

    # Auto-select learning rate using 1 / sqrt(d) scaling law,
    # relative to a reference width of 2**14 hidden units
    scale = hidden_dim / (2 ** 14)
    learning_rate = 2e-4 / scale ** 0.5

    restored_opt_state = None
    if args.resume is None:
        print(f"Initializing model with embed_dim={embed_dim}, hidden_dim={hidden_dim}, k={k}")

        # Initialize model
        key = jax.random.key(42)
        model = SparseAutoencoder.init(key, embed_dim, hidden_dim, k, bias_init=0.0)

        start_epoch = 0
        trained_steps = 0
    else:
        # Resume from checkpoint
        print(f"Resuming training from checkpoint: {args.resume}")

        # NOTE(review): the middle element is presumably the optimizer state
        # that save_checkpoint was given -- confirm against sae_save_load.
        model, restored_opt_state, trained_steps = load_checkpoint(args.resume)

        # Calculate which epoch we're starting from. Guard against a dataset
        # smaller than one batch, which would otherwise divide by zero.
        # NOTE(review): assumes data_generator yields shape[0] // batch_size
        # batches per epoch -- confirm how it treats a trailing partial batch.
        steps_per_epoch = max(1, inputs_np.shape[0] // batch_size)
        start_epoch = trained_steps // steps_per_epoch
        print(f"Resuming from epoch {start_epoch}, step {trained_steps}")

    print(f"Auto-selected learning rate: {learning_rate:.2e}")
    print(f"Training for {num_epochs} epochs with batch size {batch_size}")

    # Create custom train function with checkpointing
    def train_with_checkpoints(model, inputs, num_epochs, start_epoch=0, trained_steps=0,
                               initial_opt_state=None):
        """Train with periodic checkpointing.

        Args:
            model: Model supplying the initial ``params`` and the static
                ``k`` / ``embed_dim`` / ``hidden_dim`` settings.
            inputs: Training data handed to ``data_generator``.
            num_epochs: Epoch index at which training stops.
            start_epoch: First epoch index to run (non-zero when resuming).
            trained_steps: Number of optimizer steps already taken.
            initial_opt_state: Optional optimizer state restored from a
                checkpoint; used only if its pytree structure matches a
                freshly initialised state.

        Returns:
            A ``SparseAutoencoder`` carrying the trained parameters.
        """
        # Setup optimizer
        optimizer = optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adam(learning_rate=learning_rate)
        )
        opt_state = optimizer.init(model.params)
        if initial_opt_state is not None:
            # Reuse the checkpointed state so Adam's moments and step count
            # survive a resume; fall back to the fresh state if the layout
            # does not match what this optimizer expects.
            restored_layout = jax.tree_util.tree_structure(initial_opt_state)
            if restored_layout == jax.tree_util.tree_structure(opt_state):
                opt_state = initial_opt_state
            else:
                print("Warning: checkpoint optimizer state does not match the "
                      "optimizer; starting from a fresh optimizer state")

        # Number of inactive latents requested from the tracker on every step
        k_aux = embed_dim // 2
        inactive_threshold = 10_000

        # Setup training loop (simplified version of the train function)
        from sae_jax import InactiveLatentTracker, jitted_train_step, data_generator

        # Setup inactive latent tracker
        inactive_latent_tracker = InactiveLatentTracker.create(hidden_dim, inactive_threshold)

        def build_model(current_params):
            """Wrap ``current_params`` in a model with the original settings."""
            return SparseAutoencoder(
                params=current_params,
                k=model.k,
                embed_dim=model.embed_dim,
                hidden_dim=model.hidden_dim
            )

        nan_loss_count = 0
        inf_loss_count = 0

        params = model.params
        step = trained_steps
        metrics = None  # stays None if no training step runs at all
        training_start_time = time.time()

        # Training loop with checkpoints
        for epoch in range(start_epoch, num_epochs):
            epoch_start_time = time.time()

            for batch in data_generator(inputs, batch_size):
                top_k_inactive = inactive_latent_tracker.get_top_k_inactive_latents(k_aux)

                # Perform training step
                params, opt_state, loss, metrics = jitted_train_step(
                    params,
                    opt_state,
                    batch,
                    top_k_inactive,
                    model.k,
                    model.embed_dim,
                    model.hidden_dim,
                    optimizer
                )

                # Update inactive latent tracking
                inactive_latent_tracker = inactive_latent_tracker.update(metrics["top_active_mask"])

                # Check for NaN or Inf in losses
                total_loss = metrics["Total Loss"]
                if jnp.isnan(total_loss):
                    nan_loss_count += 1
                if jnp.isinf(total_loss):
                    inf_loss_count += 1

                step += 1

                # Log every 100 steps
                if step % 100 == 0:
                    print(f"Epoch {epoch+1}/{num_epochs}, Step {step}: Loss = {metrics['Total Loss']:.6f}")

            # End of epoch
            epoch_time = time.time() - epoch_start_time
            print(f"Epoch {epoch+1} completed in {epoch_time:.2f} seconds")

            # Save checkpoint at regular intervals and after the last epoch
            if (epoch + 1) % checkpoint_every == 0 or epoch + 1 == num_epochs:
                # The path is passed without a ".pkl" suffix, exactly as the
                # previous code did by slicing the suffix off.
                checkpoint_stem = os.path.join(output_dir, f"sae_checkpoint_epoch{epoch+1}")
                save_checkpoint(build_model(params), opt_state, step, checkpoint_stem)

        # Create the final model with trained parameters
        trained_model = build_model(params)

        total_time = time.time() - training_start_time
        print(f"Training completed in {total_time:.2f} seconds")

        # No step ran (e.g. resumed from an already finished run): there is no
        # loss to report, so record NaN instead of failing on an unbound name.
        final_loss = float(metrics['Total Loss']) if metrics is not None else float('nan')

        # Save metadata
        metadata = {
            'training_time': total_time,
            'final_loss': final_loss,
            'epochs': num_epochs,
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'inactive_threshold': inactive_threshold,
            'embed_dim': embed_dim,
            'hidden_dim': hidden_dim,
            'k': k,
            'nan_loss_count': nan_loss_count,
            'inf_loss_count': inf_loss_count
        }

        metadata_path = os.path.join(output_dir, "training_metadata.pkl")
        save_metadata(metadata, metadata_path)

        return trained_model

    # Train the model with checkpointing
    trained_model = train_with_checkpoints(
        model,
        inputs_np,
        num_epochs,
        start_epoch,
        trained_steps,
        initial_opt_state=restored_opt_state
    )

    # Save the final model
    model_path = os.path.join(output_dir, f"k{k}_final_sae_model.pkl")
    save_model(trained_model, model_path)

    print("All processing complete!")


if __name__ == "__main__":
    import argparse
    print("Devices:", jax.devices())
    # NOTE: json and AutoTokenizer are not referenced below; kept from the
    # original entry point in case other tooling relies on them being loaded.
    import json
    from transformers import AutoTokenizer
    from jax.experimental import mesh_utils
    from jax.sharding import PositionalSharding

    parser = argparse.ArgumentParser(description="Train a Sparse Autoencoder")
    parser.add_argument("--resume", type=str, help="Path to checkpoint to resume from", default=None)
    parser.add_argument("--mode", type=str, choices=["train", "preprocess"],
                        help="Training mode: 'train' for training the model, 'preprocess' for data preprocessing",
                        default="train")
    parser.add_argument("--model_name", type=str,
                        help="Name of the model to use (e.g., 'google/gemma-7b', 'meta-llama/Llama-2-7b')",
                        default="google/gemma-7b")
    args = parser.parse_args()

    if args.mode == "preprocess":
        # Destination directory, shared with the loader used by main()
        unembeddings_dir = _home_path("unembeddings", _model_slug(args.model_name))
        os.makedirs(unembeddings_dir, exist_ok=True)

        # Shard the matrix row-wise across all visible devices
        num_devices = len(jax.devices())
        sharding = PositionalSharding(mesh_utils.create_device_mesh((num_devices, 1)))
        hf_model = AutoModelForCausalLM.from_pretrained(
            args.model_name, cache_dir=_home_path("gemma_cache")
        )

        output_weights = hf_model.get_output_embeddings().weight.detach().numpy()
        g = jax.device_put(jnp.array(output_weights), sharding)

        # Whiten: centre each column, then replace every singular value by 1
        # (u @ vt keeps the singular vectors and drops the scales).
        g = g - g.mean(axis=0)
        u, _, vt = jnp.linalg.svd(g, full_matrices=False)
        g = u @ vt

        jnp.save(os.path.join(unembeddings_dir, _UNEMBEDDINGS_FILENAME), g)
    else:
        main(args)
